{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 🔍 Predicting Item Prices from Descriptions (Part 6)\n",
    "---\n",
    "- Data Curation & Preprocessing\n",
    "- Model Benchmarking – Traditional ML vs LLMs\n",
    "- E5 Embeddings & RAG\n",
    "- Fine-Tuning GPT-4o Mini\n",
    "- Evaluating LLaMA 3.1 8B Quantized\n",
    "- ➡️ Fine-Tuning LLaMA 3.1 with QLoRA\n",
    "- Evaluating Fine-Tuned LLaMA\n",
    "- Summary & Leaderboard\n",
    "\n",
    "---\n",
    "\n",
    "# ⚙️ Part 6: Fine-Tuning LLaMA 3.1 with QLoRA\n",
    "\n",
    "- 🧑‍💻 Skill Level: Advanced\n",
    "- ⚙️ Hardware: ⚠️ GPU required - use Google Colab (A100)\n",
    "- 🛠️ Requirements: 🔑 HF Token, wandb API Key ([Weights & Biases](https://wandb.ai))\n",
    "- Tasks:\n",
    "    - Load and split dataset (Train/validation); set up [Weights & Biases](https://wandb.ai) logging\n",
    "    - Load quantized LLaMA 3.1 8B and tokenizer\n",
    "    - Prepare data with a collator for fine-tuning\n",
    "    - Configure QLoRA (LoRAConfig), training settings (SFTConfig), and tune key hyperparameters\n",
    "    - Fine-tune and push best model to Hugging Face Hub\n",
    "\n",
    "⚠️ I attempted to fine-tune the model on the full 400K dataset using an A100 on Google Colab, but it consistently crashed. So for now, I’m training on a 20K subset to understand the process, play with hyperparameters, track progress in Weights & Biases, and push the best checkpoint to the Hub.\n",
    "\n",
    "⏱️ Training on 20,000 examples took over 2 hours.\n",
    "\n",
    "The full model fine-tuned on the complete 400K dataset is available thanks to our instructor, [Ed](https://www.linkedin.com/in/eddonner) — much appreciated!  \n",
    "We’ll dive into that model in the next notebook — **stay tuned** 😉\n",
    "\n",
    "---\n",
    "📢 Find more LLM notebooks on my [GitHub repository](https://github.com/lisekarimi/lexo)"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "MDyR63OTNUJ6",
    "outputId": "525372ce-f614-44f1-b894-80e289958197"
   },
   "outputs": [],
   "source": [
    "# Install required packages in Google Colab\n",
    "%pip install -q datasets transformers torch peft bitsandbytes trl accelerate wandb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "-yikV8pRBer9"
   },
   "outputs": [],
   "source": [
    "# imports\n",
    "\n",
    "import os\n",
    "import torch\n",
    "import wandb\n",
    "from google.colab import userdata\n",
    "from datetime import datetime\n",
    "from datasets import load_dataset\n",
    "from huggingface_hub import login\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, EarlyStoppingCallback\n",
    "from peft import LoraConfig\n",
    "from trl import SFTTrainer, SFTConfig, DataCollatorForCompletionOnlyLM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Google Colab User Data\n",
    "# Ensure you have set the following in your Google Colab environment:\n",
    "hf_token = userdata.get('HF_TOKEN')\n",
    "login(hf_token, add_to_git_credential=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "B48QsPsvUs_x"
   },
   "source": [
    "## 🔀 Load Dataset from HF and Split into Train/Validation"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# #If you face NotImplementedError: Loading a dataset cached in a LocalFileSystem is not supported run:\n",
    "# %pip install -U datasets (for Google Colab)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 177,
     "referenced_widgets": [
      "6f1f8dca2a334818a36fae380818001e",
      "6d3be1ece4a949d3b8d3736db02bcb5c",
      "c8c6bbacfe254c539f4acda8cdd5c04d",
      "db87c136ff15430892aa75fa47521b0c",
      "1d56af1140034021b2aecc5df846e499",
      "6238783102084e0c99626bf948ff5bb6",
      "f523b67e652049f7b13131d2750325bb",
      "f03cc2cf18c140c8b4a076ab99ac86e3",
      "472bb957b0e149df8ef0c26c3a3ffc19",
      "86dfcc161f2d41a7a33041848766d091",
      "6a7ed9e79ebb4f9c9962d08c78b424ca",
      "efc4817d5f734852a844640ebe7eceed",
      "0b473a8e944c4b028f51f53f62b72deb",
      "1fd89859568440f58f3ab56f32183dd4",
      "2e4bd8853acc4faa92e461210df2c689",
      "3fb588f271db4b7abb9a3631582cc7d6",
      "8f9c00ca63ca47e9873ec2a743fa1512",
      "afdae504b36845b9a98874cced112721",
      "8afd0ddfdeca43b59207a8b35a35e13c",
      "0be7a6fdb206420d88b2b2e45a37432c",
      "00f0983c1d204862b589011100297ffe",
      "8c7de85bcec742ec85f1e8b854351056",
      "5847c75b6dd74bc1b13116d91431ccf2",
      "bcb0ad86493f45848895c02c0b9deaf6",
      "18d70754531248b1ab22e1fd0df061ae",
      "028d806f909f42e2b6a7ec630f6e3cb5",
      "ff00d3192c734b398f779c7fffde57c8",
      "55388dcb89f84c7ebe7f5f7051f2d98b",
      "d3cab2b162a740fb82f78f030ea32b45",
      "cea0149336be4c92952bacb8aa820926",
      "6b560f8a028c4ba39896fd97f48f18ad",
      "2a3ed922dab44648b6d6ed63e21c549d",
      "885e1f4b9c3d45d5acd8d0a368ca557d",
      "73e42dca7c4b455f8be4b34236e6ced2",
      "c36aec28025e4baab8a3c4a293297f15",
      "7569e26e1e2b46e4a7018e1bd2bc92d5",
      "9f5795d223e74f1e8e49709ec1e4ddf1",
      "5638ccb893164fc79980eb48d06909f9",
      "70a528a0a08e4931b845ecc0992e07d6",
      "669bbecd55804849bff5a850438d905d",
      "245de1eaef2840b69e6c82afee68b4dc",
      "ad57405b8f474c0aa92833f83dde70e8",
      "cb3391329a7f4d0b93f5efffb9b0dcfe",
      "cb0007dffa284be8aff41efacdfc31cb",
      "c7de048747a24f9a9ce85396b87b8250",
      "066b3f278ec24b299504cea66b3c3e63",
      "0e1069c5bf644531902c51283a6d68e1",
      "06bd7477f9fe45d0ad4138fc21bd29dc",
      "adb68e7a8bea4b77b960e412c67a6286",
      "39ec099d38f04f4e8ea334d0c5335e2f",
      "044bf34d53024427801e24fbca808dc1",
      "e3d2839112ff4b7f9ab5bc04900ff522",
      "f620e7774fa04ed0a88d2f78d2243906",
      "7a12c0d7b32b445f978809c9aee2c62d",
      "5a230441445746d59ea8a10a4d5bb467"
     ]
    },
    "id": "XEE1FrSIh-EF",
    "outputId": "8cd19745-2f6f-41e0-96dd-5a2f72ac3a63"
   },
   "outputs": [],
   "source": [
    "HF_USER = \"lisekarimi\" # your HF name here!\n",
    "\n",
    "DATASET_NAME = f\"{HF_USER}/pricer-data\"\n",
    "dataset = load_dataset(DATASET_NAME)\n",
    "train = dataset['train']\n",
    "test = dataset['test']\n",
    "split_ratio = 0.1  # 10% for validation\n",
    "\n",
    "##############################################################################\n",
    "# Optional: limit training dataset to TRAIN_SIZE for testing/debugging\n",
    "# Comment the two lines below to use the full dataset\n",
    "TRAIN_SIZE = 20000\n",
    "train = train.select(range(TRAIN_SIZE))\n",
    "##############################################################################\n",
    "\n",
    "total_size = len(train)\n",
    "val_size = int(total_size * split_ratio)\n",
    "\n",
    "val_data = train.select(range(val_size))\n",
    "train_data = train.select(range(val_size, total_size))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "lUPNqb2Bse21",
    "outputId": "a3d09c8f-ce5a-46b0-e1b0-b4471a659f69"
   },
   "outputs": [],
   "source": [
    "print(f\"Train data size     : {len(train_data)}\")\n",
    "print(f\"Validation data size: {len(val_data)}\")\n",
    "print(f\"Test data size      : {len(test)}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "wixbM-VeVfsR"
   },
   "source": [
    "## 🛠️ Hugging Face Configuration"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 35
    },
    "id": "OixVUG06VmZk",
    "outputId": "3cb523e0-fd03-4a18-913b-c22fa90e3bdd"
   },
   "outputs": [],
   "source": [
    "PROJECT_NAME = \"llama3-pricer\"\n",
    "\n",
    "# Run name for saving the model in the hub\n",
    "\n",
    "RUN_NAME =  f\"{datetime.now():%Y-%m-%d_%H.%M.%S}-size{total_size}\"\n",
    "PROJECT_RUN_NAME = f\"{PROJECT_NAME}-{RUN_NAME}\"\n",
    "HUB_MODEL_NAME = f\"{HF_USER}/{PROJECT_RUN_NAME}\"\n",
    "HUB_MODEL_NAME"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "1-t1nGgnVTU4"
   },
   "source": [
    "## 🛠️ wandb Configuration"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load from Colab's secure storage\n",
    "wandb_api_key = userdata.get('WANDB_API_KEY')\n",
    "\n",
    "# Load from environment variables (.env file) if running Locally (GPU setup)\n",
    "# wandb_api_key = os.getenv('WANDB_API_KEY')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "os.environ[\"WANDB_API_KEY\"] = wandb_api_key\n",
    "wandb.login()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 156
    },
    "id": "yJNOv3cVvJ68",
    "outputId": "0c03623e-6887-49e3-8989-bbe45dfc5d35"
   },
   "outputs": [],
   "source": [
    "# Configure Weights & Biases to record against our project\n",
    "\n",
    "LOG_TO_WANDB = True\n",
    "\n",
    "os.environ[\"WANDB_PROJECT\"] = PROJECT_NAME\n",
    "os.environ[\"WANDB_LOG_MODEL\"] = \"checkpoint\" if LOG_TO_WANDB else \"end\"\n",
    "os.environ[\"WANDB_WATCH\"] = \"gradients\"\n",
    "\n",
    "if LOG_TO_WANDB:\n",
    "  wandb.init(project=PROJECT_NAME, name=RUN_NAME)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "qJWQ0a3wZ0Bw"
   },
   "source": [
    "## 📥 Load the Tokenizer and Model"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 418,
     "referenced_widgets": [
      "1b88f6d4010f4451a58abe2c46b74f62",
      "139758ba39964f49b65eb67182eef68e",
      "9c138d12dcb644fe9b72bd9eb5d26637",
      "3bf8626162904a15932480ddbcea0ebd",
      "a919a41b53604ccd91331d3f713e1310",
      "5b8cdfe01f9a4c248e3de30442411ad4",
      "e14d38a4c3e04d68ac30d475b0db1a73",
      "dadfd3c2a521420890092be265c0aa50",
      "761e88b179104dbbb6455ba81bd1f833",
      "11f5b4df0c7344ba9e188f4eca82886f",
      "125aa3f0dbd744eb82f8e4de94199736",
      "6ca21586e6fc4a608adedba7889eadb5",
      "023eb92e8a2b4323bfd12582e3c23962",
      "c7c76b9845174e9687107595df27c050",
      "78d4a28e03db4775b6e8e071c0b02d5d",
      "8483c625762c49679877a37ab0ddcef9",
      "1df5f6fe2fc04e60bfcb1f78689824ba",
      "add10c416e334928af303d51dfd745c6",
      "5e9e9dac85014292b94d347cc4bad3fe",
      "d665aa6480624ab697f4e426b51d59de",
      "03cce0d3f3a443fc808915b101576e4b",
      "f15714023f234c39863b34d1a3721a8e",
      "8f7a48d803eb4d2182c9da07af743ac7",
      "74892e7b343d410bbbef60c64a823a9a",
      "d6a70560831144e39dc9762d397f4c90",
      "9b969f7fbcdc491cab71aac42761cd2a",
      "d31f9443d1c646309c7a5e1ec39ffc0e",
      "0f5a81846ab143bebf6ec422cda3f145",
      "f0b05f3f7f37414c9d09470c94e304d7",
      "d18784692c9c4ca99e277e6ed51e2bf1",
      "f58addfac7c3438a90ebf10c88348d56",
      "451deac2eeec45598590579340be0d4b",
      "848e0651caf34ef288cca451e3d11274",
      "5adf041222f843429c3a9f1b99becfed",
      "a4764f36570b4752a1ec4392d2f0146c",
      "511a4c6a898346acac9d98fd3a7cdf2c",
      "26da7435a2614201a9e5b8087749f0e0",
      "6054fa015ae44659beb7473c084c7b5b",
      "3b9fc447a9ae4506a1edaf0fa449d9d5",
      "6acef8f1820545ef90b22d90ac80427d",
      "2a5cbad0b8fd45dc9ee25715b1015aef",
      "86a9428f39be4d65a1e922bd9afb3800",
      "96d919a1a7f14e91b8e6c91d855e36d5",
      "82d7484aa2774015b7ea18d933afa9b6",
      "b9d2d4f2c44a4d7cad2b3803c7f6e7be",
      "9f3a176a6ae6426a8c1567a835da8680",
      "006763d2301f4205a588adf5c19876a0",
      "b44eb6596c3441bbaab288030f953a04",
      "bf91666a0c054c79acb03d2e1bb38c37",
      "f0185f1b4b23445c920a873eb63a9372",
      "8e1ac15b677d4c21ad42ea1dda68fe05",
      "87746d8d6d3d413ebb46b4e12fb74cc8",
      "bb5ea1e92c434a46838f943648de87bd",
      "1abcfcba332b40eb901d1331ed84f9bd",
      "52fa5fcc629742619fa3105f73d90767",
      "1bcc2d5771034c2dbc372031e83a2384",
      "221cfaa2a5db4cf1ac399363c3589025",
      "793f9bdc92a545519dd3279023e4ab50",
      "55e25f5cc12f44f3a39fae501fccd060",
      "59463b5e6286483394dedb602991ac95",
      "fc95344ea44d40f28702360542afcff7",
      "ffb3af537d6c41548ad88027505b04d6",
      "6afcf0f6131d4dddbeda796e9c0c5bc5",
      "93f65b3bc071453f86fe8f0f6c17d8fd",
      "2ac9926ee4644232b43d84cfa95c584d",
      "0c5a7738132b4f0f8b4810333b37c588",
      "99d41ffa37134be9a57fe5e50a59b67d",
      "50e71304ab4f42c29f1994fed9b595b8",
      "76b4b0d63e524eb783429169a25be74e",
      "441cfadbe4b446f4b61391b7be4d2865",
      "6751f0c35b634d7c9b06c4e41f9ff851",
      "6a5dc276bbf64bf9b5a99751068ee228",
      "b3ac6055014642a285435f877d5651f5",
      "e9137600b29c4ecaad4ef8bca5fd5f91",
      "634afb9c1b8c4e29b3ec7b76a1108ae4",
      "6be0ac91035548fbbe778e3d7fd58e7e",
      "e8e9d5c979ac4afba526e38b6d0851be",
      "a4ae8ca9c0e7478fbad3b9ed67bc21a2",
      "faf3a64e316a43ddbac8ba14573c4eb4",
      "a395885e39434f9f98246d0fb1c94c8f",
      "d13552c90ead4804a4d5a21121f25536",
      "c25b94002c2246a9aa7f6ed1e4a22cfa",
      "e3892cf602cb4a49948f26cae1e7644c",
      "bc290a324a7147c5b6a722acb41ed05a",
      "2b556f5aa6324958ac6fe36bddf17909",
      "67c6a0534b3a4345b9c11af1bffdfbf0",
      "d767921bb23c485396282cb79a4d1836",
      "d598468ad8f94146976f70d873f0b56d",
      "b547888cd5494b21911b7d457ab6fbac",
      "28362e43274848109c2624e5668942b0",
      "7a27fc65bc0b44ce9bd959f4be13514d",
      "73bc97e6d9cc4ccd8d134092ce970026",
      "c042bf08ab23410098e6d16e837d19ce",
      "d2930ad2c08748d0883bb77c68acf940",
      "c2a1291730874e8e94232c0d51575f81",
      "cb92871b11a0410eb295cc323e5872a7",
      "150a5ce5d8124b0eb9e44d8715b8b1ab",
      "7a6f05ad1f2e483dbcdca102c66530b0",
      "626a29aee42e4e6d8c18d8ea5889734a",
      "c549ca0548d04a7d8749a0842c4aa62b",
      "958c0ff0f47f4c0fa4e2085f5243d84f",
      "a8171febcac94a4b902ff737592f3f47",
      "22630cdb7d6f4975bc31cc189987573d",
      "2f8a9ccee6ea4cdd8c8c225575cae0ce",
      "e40f81c5c4334accbca947964146d238",
      "d6849da8e89546469188dc047c66ea25",
      "8a67d8a2ac0a4fd7a41aa5c890049525",
      "5bf18445be0e46e087cbcd377ccfffbe",
      "72b2020c9479471681ce0f42898cfe1c",
      "c114fd62eb4b4fdca94654668c8f2374",
      "401580df26fc40abb2b774c3d9684921",
      "e756b825b211476994a69fb65f4bbf7c",
      "b2c26cf10e5a4d4fa8961f5c9cca18ce",
      "c288256c73dd44d08916db4e9cf989f0",
      "250a72e9650845d2b274bc3c157439f8",
      "94281c7e5be049c1a9f3dfa082805133",
      "f004f9f743ae4229aa90c92abba6ded6",
      "bd8ca5b8aaed4809a93f553d5cb4a887",
      "4cec4c2d73de4d52b2143082645536ac",
      "893b96616a0e47bfaa0434e10eca1341",
      "74e7d88dd4894894ac2c16fdfd29233b",
      "9e1f1e4288df407fa03415664dc361d5",
      "81dc3f390b9a49f8b1be5c43580b070d",
      "917a225a9bb74f8ab034dcdcee3c7247",
      "bc6c698857ce4f8eabc1571ba0ff0edf",
      "e9ae1c247ae5409f9da4db84ce71a6e3",
      "55071660223e4022a6a7836572077c0c",
      "8364e661011743af9fd40dabc5a7dfe4",
      "ac65442e0d5e43e2998d7c700573228a",
      "666f3434ae8a495f8ada8fedb50b7051",
      "1977e9f07f104faead7dfcfa8aaed6f2",
      "ebe2257c07f345fea72f162542a45142"
     ]
    },
    "id": "R_O04fKxMMT-",
    "outputId": "29aa1cf7-2a2e-492e-adc9-cd0a5bfb123e"
   },
   "outputs": [],
   "source": [
    "BASE_MODEL = \"meta-llama/Meta-Llama-3.1-8B\"\n",
    "\n",
    "quant_config = BitsAndBytesConfig(\n",
    "    load_in_4bit=True,  # Reduce the precision to 4 bits\n",
    "    bnb_4bit_use_double_quant=True,\n",
    "    bnb_4bit_compute_dtype=torch.bfloat16,\n",
    "    bnb_4bit_quant_type=\"nf4\"\n",
    ")\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)\n",
    "tokenizer.pad_token = tokenizer.eos_token\n",
    "tokenizer.padding_side = \"right\"\n",
    "\n",
    "base_model = AutoModelForCausalLM.from_pretrained(\n",
    "    BASE_MODEL,\n",
    "    quantization_config=quant_config,\n",
    "    device_map=\"auto\",\n",
    ")\n",
    "base_model.generation_config.pad_token_id = tokenizer.pad_token_id\n",
    "\n",
    "print(f\"Memory footprint: {base_model.get_memory_footprint() / 1e6:.1f} MB\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "SrCE2Le7RBRj"
   },
   "source": [
    "## ⚙️ Fine-tune our LLaMA 3 8B (4-bit quantized) model with QLoRA\n",
    "- 1. Prepare the Data with a Data Collator\n",
    "- 2. Define the QLoRA Configuration (LoraConfig)\n",
    "- 3. Set the Training Parameters (SFTConfig)\n",
    "- 4. Initialize the Fine-Tuning Trainer (SFTTrainer)\n",
    "- 5. Run Fine-Tuning and Push to Hub"
   ],
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "9BYO0If4uWys"
   },
   "source": [
    "### 🔄 1. Prepare the Data with a Data Collator\n",
    "\n",
    "We only want the model to learn the price, not the product description. Everything before \"Price is $\" is context, not training target. HuggingFace’s DataCollatorForCompletionOnlyLM handles this masking automatically:\n",
    "\n",
    "1. Tokenizes the response_template (\"Price is $\")\n",
    "2. Finds its token position in each input\n",
    "3. Masks all tokens before it (context)\n",
    "4. Trains the model only on tokens after it (the price)\n",
    "\n",
    "\n",
    "Example:\n",
    "\n",
    "Input: \"Product: Red T-shirt. Price is $12.99\"\n",
    "\n",
    "Masked: \"Product: Red T-shirt. Price is $\" → masked (no loss)\n",
    "\n",
    "\"12.99\" → not masked (model is trained to predict this)\n",
    "\n",
    "So the model learns to generate 12.99 given the context, but isn’t trained to repeat or memorize the description."
   ],
   "outputs": []
  },
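  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the masking concrete, here is a minimal, library-free sketch of the idea described above. It is an illustration under stated assumptions, not TRL's actual implementation: the token IDs are made up, `mask_prompt_tokens` is a hypothetical helper, and `-100` is the label value that PyTorch's cross-entropy loss ignores by default."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustration only: made-up token IDs, not real LLaMA tokens\n",
    "IGNORE_INDEX = -100  # label value skipped by PyTorch's cross-entropy loss by default\n",
    "\n",
    "\n",
    "def mask_prompt_tokens(input_ids, template_ids):\n",
    "    \"\"\"Return labels where everything up to and including the template is ignored.\"\"\"\n",
    "    n = len(template_ids)\n",
    "    for start in range(len(input_ids) - n + 1):\n",
    "        if input_ids[start:start + n] == template_ids:\n",
    "            end = start + n\n",
    "            return [IGNORE_INDEX] * end + input_ids[end:]\n",
    "    # Template not found: this sketch ignores the whole example rather than train on the prompt\n",
    "    return [IGNORE_INDEX] * len(input_ids)\n",
    "\n",
    "\n",
    "input_ids = [11, 12, 13, 14, 7, 8, 9, 51, 52]  # description + 'Price is $' + price\n",
    "template_ids = [7, 8, 9]                       # hypothetical IDs for 'Price is $'\n",
    "\n",
    "# Only the last two positions (the price tokens) keep real labels\n",
    "print(mask_prompt_tokens(input_ids, template_ids))"
   ]
  },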
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "2omVEaPIVJZa"
   },
   "outputs": [],
   "source": [
    "response_template = \"Price is $\"\n",
    "collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "4DaOeBhyy9eS"
   },
   "source": [
    "### 🧠 2. Define the QLoRA Configuration (LoraConfig)"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "0HKuVS_XR3cw"
   },
   "outputs": [],
   "source": [
    "LORA_R = 32\n",
    "LORA_ALPHA = 64\n",
    "TARGET_MODULES = [\"q_proj\", \"v_proj\", \"k_proj\", \"o_proj\"]\n",
    "LORA_DROPOUT = 0.1\n",
    "\n",
    "lora_parameters = LoraConfig(\n",
    "    r=LORA_R,\n",
    "    lora_alpha=LORA_ALPHA,\n",
    "    target_modules=TARGET_MODULES,\n",
    "    lora_dropout=LORA_DROPOUT,\n",
    "    bias=\"none\",\n",
    "    task_type=\"CAUSAL_LM\", # Specifies we're doing causal language modeling\n",
    ")"
   ]
  },
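  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For a sense of scale, the adapter size implied by these settings can be estimated by hand. Each adapted linear layer with input size `d_in` and output size `d_out` gains two low-rank matrices holding `r * (d_in + d_out)` parameters in total. The layer shapes in the sketch below are assumptions based on the published LLaMA 3.1 8B architecture (hidden size 4096, 32 layers, 1024-dimensional key/value projections from grouped-query attention), so check them against `base_model.config` before relying on the result."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rough LoRA adapter size; layer shapes are assumptions, verify against base_model.config\n",
    "R = 32          # same value as LORA_R above\n",
    "ALPHA = 64      # same value as LORA_ALPHA above\n",
    "NUM_LAYERS = 32\n",
    "HIDDEN = 4096\n",
    "KV_DIM = 1024   # 8 key/value heads x 128 head dim (grouped-query attention)\n",
    "\n",
    "# (d_in, d_out) for each targeted projection\n",
    "shapes = {\n",
    "    'q_proj': (HIDDEN, HIDDEN),\n",
    "    'k_proj': (HIDDEN, KV_DIM),\n",
    "    'v_proj': (HIDDEN, KV_DIM),\n",
    "    'o_proj': (HIDDEN, HIDDEN),\n",
    "}\n",
    "\n",
    "# LoRA adds A (r x d_in) and B (d_out x r) to every adapted layer\n",
    "per_layer = sum(R * (d_in + d_out) for d_in, d_out in shapes.values())\n",
    "total = per_layer * NUM_LAYERS\n",
    "\n",
    "print(f'LoRA parameters per layer: {per_layer:,}')\n",
    "print(f'Trainable LoRA parameters: {total:,}')\n",
    "print(f'Update scaling (alpha / r): {ALPHA / R}')"
   ]
  },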
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "uLfFsfNQSBAm"
   },
   "source": [
    "### ⚙️ 3. Set the Training Parameters (SFTConfig)"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "7PKXdhPXSJot"
   },
   "outputs": [],
   "source": [
    "# 📦 Training Setup:\n",
    "EPOCHS = 1\n",
    "BATCH_SIZE = 16                     # A100 GPU can go up to 16\n",
    "GRADIENT_ACCUMULATION_STEPS = 2\n",
    "MAX_SEQUENCE_LENGTH = 182          # Max token length per input\n",
    "\n",
    "# ⚙️ Optimization:\n",
    "LEARNING_RATE = 1e-4\n",
    "LR_SCHEDULER_TYPE = 'cosine'\n",
    "WARMUP_RATIO = 0.03\n",
    "OPTIMIZER = \"paged_adamw_32bit\"\n",
    "\n",
    "# 💾 Checkpointing & Logging:\n",
    "SAVE_STEPS = 200        # Checkpoint\n",
    "STEPS = 20              # Log every 20 steps\n",
    "save_total_limit = 10   # Keep latest 10 only\n",
    "\n",
    "\n",
    "LOG_TO_WANDB = True\n",
    "\n",
    "HUB_MODEL_NAME = f\"{HF_USER}/{PROJECT_RUN_NAME}\"\n",
    "\n",
    "train_parameters = SFTConfig(\n",
    "    # Output & Run\n",
    "    output_dir=PROJECT_RUN_NAME,\n",
    "    run_name=RUN_NAME,\n",
    "    dataset_text_field=\"text\",\n",
    "    max_seq_length=MAX_SEQUENCE_LENGTH,\n",
    "\n",
    "    # Training\n",
    "    num_train_epochs=EPOCHS,\n",
    "    per_device_train_batch_size=BATCH_SIZE,\n",
    "    gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,\n",
    "    max_steps=-1,\n",
    "    group_by_length=True,\n",
    "\n",
    "    # Evaluation\n",
    "    eval_strategy=\"steps\",\n",
    "    eval_steps=STEPS,\n",
    "    per_device_eval_batch_size=1,\n",
    "\n",
    "    # Optimization\n",
    "    learning_rate=LEARNING_RATE,\n",
    "    lr_scheduler_type=LR_SCHEDULER_TYPE,\n",
    "    warmup_ratio=WARMUP_RATIO,\n",
    "    optim=OPTIMIZER,\n",
    "    weight_decay=0.001,\n",
    "    max_grad_norm=0.3,\n",
    "\n",
    "    # Precision\n",
    "    fp16=False,\n",
    "    bf16=True,\n",
    "\n",
    "    # Logging & Saving\n",
    "    logging_steps=STEPS,            # See loss after each {STEP} batches\n",
    "    save_strategy=\"steps\",\n",
    "    save_steps=SAVE_STEPS,          # Model Checkpointed locally\n",
    "    save_total_limit=save_total_limit,\n",
    "    report_to=\"wandb\" if LOG_TO_WANDB else None,\n",
    "\n",
    "    # Hub\n",
    "    push_to_hub=True,\n",
    "    hub_strategy=\"end\",  # Only push once, at the end\n",
    "    load_best_model_at_end=True, # Loads the best eval_loss checkpoint\n",
    "    metric_for_best_model=\"eval_loss\", # Monitors eval_loss\n",
    "    greater_is_better=False, # Lower eval_loss = better model\n",
    ")\n"
   ]
  },
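  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "It helps to translate these settings into step counts before launching a multi-hour run. The sketch below is plain arithmetic under the assumptions of this notebook's 20K subset (18,000 training and 2,000 validation examples) and a single GPU; the step count the trainer reports may differ by one depending on how the last partial batch is handled. Note also that, as far as I understand, only checkpoints that were actually saved can be restored as the best model, so with `save_steps=200` the selection is coarser than the 20-step evaluation curve."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "# Values copied from the cells above (assumes the 20K subset and a single GPU)\n",
    "train_examples = 18_000\n",
    "val_examples = 2_000\n",
    "batch_size = 16\n",
    "grad_accum = 2\n",
    "eval_every = 20    # eval_steps\n",
    "save_every = 200   # save_steps\n",
    "patience = 5       # early_stopping_patience\n",
    "\n",
    "effective_batch = batch_size * grad_accum\n",
    "steps_per_epoch = math.ceil(train_examples / effective_batch)\n",
    "num_evals = steps_per_epoch // eval_every\n",
    "\n",
    "print(f'Effective batch size : {effective_batch}')\n",
    "print(f'Optimizer steps/epoch: ~{steps_per_epoch}')\n",
    "print(f'Evaluations per epoch: {num_evals} ({num_evals * val_examples:,} validation forward passes at eval batch size 1)')\n",
    "print(f'Early stopping window: {patience * eval_every} steps without eval_loss improvement')\n",
    "\n",
    "# load_best_model_at_end expects checkpoints to line up with evaluations\n",
    "assert save_every % eval_every == 0"
   ]
  },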
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "1q-a3LHDSoxQ"
   },
   "source": [
    "### 🧩 4. Initialize the Fine-Tuning Trainer (SFTTrainer)\n",
    "Combining everything"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 290,
     "referenced_widgets": [
      "6753caf741414a4c8fa309978253c8cd",
      "aeade430d57b4338910ad0c3645fd06a",
      "eb7081b71cc14aff9b99dba8f9368def",
      "8eb16171df804d06a02351f74bb28dc4",
      "9d60a205ebda49ca88220cc4eec716ca",
      "d8ff973b90374423b4b5e17a1937111c",
      "4bf3bf107f2c4e28a58387c96916e97f",
      "d66cb8c1829c439095f4691fa32d7b6e",
      "567c8321685045c5a873b3b1edecdc96",
      "96ff596facb94acab611201b4adac13f",
      "de65507ce09a4ef4ad8f28d46d335acc",
      "e40fe92fe9094a58b53f0eeb97d3d629",
      "592615cc81624de5a9934f5671d6c188",
      "fadf75d91df54f49acef3f178ea53ce3",
      "5ccca8ab6cb94a88bb27bd482f7948a9",
      "d74dcc2ef9b8442d9ae99db2a79e0c48",
      "580ebfa370d34426933e8c7389872e2b",
      "1187f05dc99641e9a68d9cf49216c370",
      "7deffbba68ba4f018374bd6bec62dd18",
      "d24cdc40a6a34d6eb0efbfde17505d6f",
      "31d44a308b4b4557934ec887e0b6a817",
      "76112ce6fdc4496dba783451efa28cfd",
      "15a85e4a77484c9392b2e5cb8767b336",
      "4524d775b9034a1f890673a9c005d123",
      "5ab6a6b427f84ec685ac52f6ff0d63b5",
      "427ee9e90a844313989f623aba124498",
      "6d2b7c059e6b42afa955fe01bf38011d",
      "5d821ed8ffe14927be799c4d31043a82",
      "12f9fab59e9849dcb7b3b17c5674580f",
      "dd4a2876db37476fa438e8758c855393",
      "f115f97428764c53ac780131fd75bd17",
      "1a1e0e562a844ed098e97ce8a62695ee",
      "0a7ae7cc902243a5996f730f0fe05cdb",
      "07205ea24c3f4959bf9ebd393f5c921d",
      "723bb8342ac84eedabd91e3eef178967",
      "28714d0cf3d84a48975c8ad31e29691d",
      "dd1d90d76d914839a1dad1cddab2c09f",
      "e2d55edf98784523bcbeaad0cc2be494",
      "d00ecfa9dc44428b989ec1a9deb27eae",
      "ba2717985bc342e9827f8901ef655b00",
      "6669dc8f20e3461f93c95cef7a90b201",
      "29cb36c1943c4e1b9898534aaf32bd37",
      "14a1449c13a14afda16bc7c05b7fd840",
      "259d315eb4584c699b1c738d411eab7e",
      "a4bb13eb7cee4f87b0e3e1a3a1be18e7",
      "14d8a699a92044cda33802d96aaa41a2",
      "d345350fd5ad4a028fbbc45cfc9f6db3",
      "6953210353f840d59457fc54f4f8b829",
      "d6cd9e1196f04ecbba83dc0b446b2c65",
      "9e380ef863204da5863c9b6e7a2c8340",
      "1d1bb803831d46309619f6a0c51c2eeb",
      "6a50aaf7ad304a5aa3f29113121e8fe0",
      "7a573a39c2b245f5a84626d951584f67",
      "a57e66367d4245f6bcd4ad0463535583",
      "d6f3327d39a34ec5a44d976f239a61ce",
      "8f450df9f161409a8102c1f0b63edad8",
      "95d932d12cb8442da17adb8e9782c40c",
      "41c5f295b45f4828a9327b699b85ca01",
      "9e4f3fd6bf7749f88ccd7ba65dd9446f",
      "a8f8cb0d9fb14f30a537977f3d51a2c4",
      "4e9e4ed0f2db4d7ba5a5bb0d00676a0c",
      "1fe2bab9c9aa4de48e6e2512f9a7d0a1",
      "d93ac5affccf404fa3916e7f3dd62943",
      "92346fc65f48493d80198ac6d7adf4d8",
      "647bfb2a24cc44a0adaf69ced8e99213",
      "5c96424cff314aa484e4bc905bcbd761",
      "cec2fcfb30194d5ab8c0a3868bad3598",
      "35df7031c4964cef9c53bba6eabbe91d",
      "e15c772e14264c9889e6dae34015e04b",
      "e85b65cb497c48c2b844ae3e5d9efc60",
      "52c8495d46ca4a3c8c6694a700d05e95",
      "3db6d8a5ce2a40daaae6714807a27997",
      "051d74df7ef1468aa968cac5792e7b00",
      "75838a7c887545ff9fbbf5887a1336bc",
      "59f698c1829148ac90edda008d5c6f69",
      "35921436c69643aab792bd1333c749ef",
      "2dd51cc6033746e1a8def460e5e51ff5",
      "a8a3e5973ee5441087d10dfb17bfa1d6",
      "64c3b3c02e844df6bfd3acf1ee23d765",
      "83016eccdd7f4dedab9d3ea6e6852977",
      "9d4c5a62214f4649b77365349ae4ac88",
      "07cb9756d1814a7ba7fb49cccb2763cb",
      "492454ad524742bd8bb3f5c3d5b37feb",
      "e98053f6b7f045da812088d1e76d3a31",
      "f2aeb3ae99cc4b7ca97fb959df1150ad",
      "f92e18b6ab0147b1b428724f5155ca61",
      "14356b2447e349ee8478478eb231fa81",
      "f244a7e331d941f5a99712dcbc5550ea"
     ]
    },
    "id": "fCwmDmkSATvj",
    "outputId": "2b4adc75-e0db-4e0b-c90b-9f9ff2dfd3c6"
   },
   "outputs": [],
   "source": [
    "# The latest version of trl is showing a warning about labels - please ignore this warning\n",
    "fine_tuning = SFTTrainer(\n",
    "    model=base_model,\n",
    "    train_dataset=train_data,\n",
    "    eval_dataset=val_data,\n",
    "    peft_config=lora_parameters,    # QLoRA config\n",
    "    args=train_parameters,          # SFTConfig\n",
    "    data_collator=collator,\n",
    "    callbacks=[EarlyStoppingCallback(early_stopping_patience=5)] # Early stop if no val improvement for 5 steps\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "vHz6JA5_XJ07"
   },
   "source": [
    "### 🚀 5. Run Fine-Tuning and Push to Hub"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 1000
    },
    "id": "GfvAxnXPvB7w",
    "outputId": "d351d89a-b3d7-4e2b-fee2-5ba2e929837e"
   },
   "outputs": [],
   "source": [
    "fine_tuning.train()\n",
    "print(f\"✅ Best model pushed to HF Hub: {HUB_MODEL_NAME}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![](https://github.com/lisek75/nlp_llms_notebook/blob/main/assets/09_train_eval_loss_steps.png?raw=true)\n",
    "\n",
    "![](https://github.com/lisek75/nlp_llms_notebook/blob/main/assets/09_train_eval_loss_wandb.png?raw=true)\n",
    "\n",
    "This chart shows training loss vs evaluation loss over steps during fine-tuning of Llama 31 8B 4-Bit FT (20K Samples).\n",
    "\n",
    "- Blue line (train/loss): Decreasing overall, with some noise. Final value: 1.8596.\n",
    "- Orange line (eval/loss): Smoother and consistently lower than training loss. Final value: 1.8103.\n",
    "\n",
    "- No overfitting: Eval loss < train loss throughout — a good sign.\n",
    "- Stable convergence: Both curves flatten around step 500, suggesting the model is reaching training stability.\n",
    "- Final eval loss is low, indicating decent generalization to unseen data.\n",
    "\n",
    "This fine-tuning run looks healthy. We can likely push further with more data - 400K run."
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 938
    },
    "id": "32vvrYRVAUNg",
    "outputId": "bb4ab0f6-c390-48f3-a71c-2d259bb0ec0b"
   },
   "outputs": [],
   "source": [
    "if LOG_TO_WANDB:\n",
    "  wandb.finish()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![](https://github.com/lisek75/nlp_llms_notebook/blob/main/assets/09_run_summary_qlora_llama.png?raw=true)"
   ],
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "IyKZ0r38IfT3"
   },
   "source": [
    "Now that our best model is pushed to Hugging Face, let’s put it to the test.\n",
    "\n",
    "🔜 See you in the [next notebook](https://github.com/lisekarimi/lexo/blob/main/09_part7_eval_llama_qlora.ipynb)"
   ],
   "outputs": []
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "gpuType": "A100",
   "provenance": []
  },
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}